{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6eef0328-55f4-4c6f-a3d6-f0ae6f80cf9d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x75b187f4ac90>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.manual_seed(42)"
   ]
  },
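  {
   "cell_type": "markdown",
   "id": "1a8a375b-4b61-4c77-8208-acaf3ba348b2",
   "metadata": {},
   "source": [
    "- Functions can be classified by whether their inputs and outputs are scalar or multivariate:\n",
    "    - single variable, single output\n",
    "    - multiple variables, single output\n",
    "        - e.g. the loss function of a neural network\n",
    "        - partial derivatives of the loss with respect to the last layer's parameters (weights/bias)\n",
    "    - single variable, multiple outputs\n",
    "    - multiple variables, multiple outputs\n",
    "        - the derivative is the Jacobian matrix\n",
    "        - derivative of the last layer with respect to the second-to-last layer (parameters)\n",
    "- In effect, this is a review of multivariable differential calculus.\n",
    "- Multivariable differentiation leads naturally into matrix analysis:\n",
    "    - at a minimum, it requires computing the Jacobian (matrix)."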
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2533f67f-0009-4517-bbb8-b411b967555e",
   "metadata": {},
   "source": [
    "## review `tensor.backward`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77cb1b94-79ac-44b4-9cb1-7b8a13d93baa",
   "metadata": {},
   "source": [
    "- for non-scalar outputs\n",
    "    - jacobian matrix "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5669b497-306e-4b58-bd97-fcfebd4a4496",
   "metadata": {},
   "source": [
    "\n",
    "$\\mathbf f: R^n\\rightarrow R^m, \\mathbf J\\in R^{m\\times n}$\n",
    "\n",
    "$$\n",
    "\\mathbf J = \\begin{bmatrix}\n",
    "  \\dfrac{\\partial \\mathbf{f}}{\\partial x_1} & \\cdots & \\dfrac{\\partial \\mathbf{f}}{\\partial x_n}\n",
    "\\end{bmatrix}\n",
    "= \\begin{bmatrix}\n",
    "  \\nabla^{\\mathrm T} f_1 \\\\  \n",
    "  \\vdots \\\\\n",
    "  \\nabla^{\\mathrm T} f_m   \n",
    "\\end{bmatrix}\n",
    "= \\begin{bmatrix}\n",
    "    \\dfrac{\\partial f_1}{\\partial x_1} & \\cdots & \\dfrac{\\partial f_1}{\\partial x_n}\\\\\n",
    "    \\vdots                             & \\ddots & \\vdots\\\\\n",
    "    \\dfrac{\\partial f_m}{\\partial x_1} & \\cdots & \\dfrac{\\partial f_m}{\\partial x_n}\n",
    "\\end{bmatrix}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "09c07284-bd8c-49e3-ad45-0ed5ebbe9179",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.arange(3., requires_grad=True)\n",
    "y = x*x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac6aef26-75f9-4ef3-b502-8e2151329628",
   "metadata": {},
   "source": [
    "\n",
    "- $y_i=x_i^2$\n",
    "\n",
    "$$\n",
    "\\mathbf y=\\begin{bmatrix}\n",
    "y_1\\\\\n",
    "y_2\\\\\n",
    "y_3\\\\\n",
    "\\end{bmatrix}=\\begin{bmatrix}\n",
    "x_1^2\\\\\n",
    "x_2^2\\\\\n",
    "x_3^2\n",
    "\\end{bmatrix}\n",
    "$$\n",
    "\n",
    "\n",
    "$$\n",
    "\\mathbf J_y=\\begin{bmatrix}\n",
    "\\frac{\\partial y_1}{\\partial x_1} & \\frac{\\partial y_1}{\\partial x_2} & \\frac{\\partial y_1}{\\partial x_3}\\\\\n",
    "\\frac{\\partial y_2}{\\partial x_1} & \\frac{\\partial y_2}{\\partial x_2} & \\frac{\\partial y_2}{\\partial x_3}\\\\\n",
    "\\frac{\\partial y_3}{\\partial x_1} & \\frac{\\partial y_3}{\\partial x_2} & \\frac{\\partial y_3}{\\partial x_3}\n",
    "\\end{bmatrix}=\\begin{bmatrix}\n",
    "2x_1 & 0 & 0\\\\\n",
    "0 & 2x_2 & 0\\\\\n",
    "0 & 0 & 2x_3\n",
    "\\end{bmatrix}\n",
    "$$\n",
    "\n",
    "- The variables are not cross-coupled (each $y_i$ depends only on $x_i$), so the Jacobian is a diagonal matrix."
   ]
  },
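  {
   "cell_type": "markdown",
   "id": "a3f1c2d4-5b6e-4f70-8a91-1c2d3e4f5a01",
   "metadata": {},
   "source": [
    "A quick numerical check of the diagonal Jacobian derived above. This is a minimal sketch, assuming the same `x = torch.arange(3.)` as in the previous code cell; `J_expected` is a name introduced here for illustration only.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Same setup as above: y_i = x_i^2 with x = [0, 1, 2].\n",
    "x = torch.arange(3., requires_grad=True)\n",
    "\n",
    "# Full Jacobian from the built-in helper: entry (i, j) is dy_i / dx_j.\n",
    "J = torch.autograd.functional.jacobian(lambda t: t * t, x)\n",
    "\n",
    "# Analytic result: a diagonal matrix with 2 * x_i on the diagonal.\n",
    "J_expected = torch.diag(2 * x.detach())\n",
    "\n",
    "print(J)\n",
    "print(torch.allclose(J, J_expected))\n",
    "```\n",
    "\n",
    "The off-diagonal zeros are exactly the statement that no $y_i$ depends on any $x_j$ with $j\\neq i$."
   ]
  },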
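  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "16e06bb9-b64f-4724-b1b3-a497824e76cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calling backward() with no arguments raises a RuntimeError here: y is not a scalar, so a `gradient` argument is required.\n",
    "# y.backward()"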
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50b096d0-2dd4-4883-81d0-6edfb29021df",
   "metadata": {},
   "source": [
    "$$\n",
    "\\mathbf J^T\\cdot \\mathbf v\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "13572c43-4e34-46fe-bf4c-0b041ec9e344",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 2., 4.])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.backward(gradient=torch.ones_like(x), retain_graph=True)\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "3ee64b0a-13fb-4f04-8054-c108c2886cd0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.,  4., 12.])"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.grad.zero_()\n",
    "y.backward(gradient=torch.tensor([1., 2., 3.]), retain_graph=True)\n",
    "x.grad"
   ]
  },
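  {
   "cell_type": "markdown",
   "id": "b4e2d3c5-6c7f-4a81-9ba2-2d3e4f5a6b02",
   "metadata": {},
   "source": [
    "As a worked check of the two `backward` calls above (a sketch that only substitutes this notebook's values, $\\mathbf x=(0,1,2)$, into $\\mathbf J^T\\cdot\\mathbf v$): the Jacobian is $\\mathbf J_y=\\mathrm{diag}(0,2,4)$, so for $\\mathbf v=(1,1,1)$ and $\\mathbf v=(1,2,3)$ respectively\n",
    "\n",
    "$$\n",
    "\\begin{split}\n",
    "\\mathbf J_y^T\\mathbf v&=\\begin{bmatrix}0&0&0\\\\0&2&0\\\\0&0&4\\end{bmatrix}\\begin{bmatrix}1\\\\1\\\\1\\end{bmatrix}=\\begin{bmatrix}0\\\\2\\\\4\\end{bmatrix}\\\\\n",
    "\\mathbf J_y^T\\mathbf v&=\\begin{bmatrix}0&0&0\\\\0&2&0\\\\0&0&4\\end{bmatrix}\\begin{bmatrix}1\\\\2\\\\3\\end{bmatrix}=\\begin{bmatrix}0\\\\4\\\\12\\end{bmatrix}\n",
    "\\end{split}\n",
    "$$\n",
    "\n",
    "Because `x.grad` accumulates across `backward` calls, the `x.grad.zero_()` before the second call is what keeps the second result from being added to the first."
   ]
  },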
  {
   "cell_type": "markdown",
   "id": "1c2f609e-3ad5-4bd2-b667-ac1bded70e36",
   "metadata": {},
   "source": [
    "### multi input tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "64767728-f42e-4827-aba9-1c09e4b0f428",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.tensor([1., 2.], requires_grad=True)\n",
    "b = torch.tensor([2., 3.], requires_grad=True)\n",
    "Q = 3*a**3 - b**2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08975c47-8013-46d6-aaf7-6828ad5b0b2c",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "&Q=\\begin{bmatrix}\n",
    "3a_1^3-b_1^2\\\\\n",
    "3a_2^3-b_2^2\n",
    "\\end{bmatrix}\\\\\n",
    "&\\mathbf J_Q = \\begin{bmatrix}\n",
    "\\frac{\\partial Q_1}{\\partial a_1} & \\frac{\\partial Q_1}{\\partial b_1} & \\frac{\\partial Q_1}{\\partial a_2} & \\frac{\\partial Q_1}{\\partial b_2}\\\\\n",
    "\\frac{\\partial Q_2}{\\partial a_1} & \\frac{\\partial Q_2}{\\partial b_1} & \\frac{\\partial Q_2}{\\partial a_2} & \\frac{\\partial Q_2}{\\partial b_2}\\\\\n",
    "\\end{bmatrix}=\\begin{bmatrix}\n",
    "9a_1^2 & -2b_1 & 0 & 0\\\\\n",
    "0 & 0 & 9a_2^2 & -2b_2\n",
    "\\end{bmatrix}\\\\\n",
    "&\\mathbf J_Q^T\\mathbf v=\\begin{bmatrix}\n",
    "9a_1^2 & 0\\\\\n",
    "-2b_1 & 0\\\\\n",
    "0 & 9a_2^2\\\\\n",
    "0 & -2b_2\n",
    "\\end{bmatrix}\\cdot \\begin{bmatrix}\n",
    "1\\\\\n",
    "1\n",
    "\\end{bmatrix}=\\begin{bmatrix}\n",
    "9a_1^2\\\\\n",
    "-2b_1\\\\\n",
    "9a_2^2\\\\\n",
    "-2b_2\n",
    "\\end{bmatrix}\n",
    "\\end{split}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "53c8b525-150e-44eb-a8e0-65e8f93ac422",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 9., 36.])\n",
      "tensor([-4., -6.])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0., 0.])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Q.backward(gradient=torch.ones(2, dtype=torch.float32), retain_graph=True)\n",
    "print(a.grad)\n",
    "print(b.grad)\n",
    "a.grad.zero_()\n",
    "b.grad.zero_()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c69c7981-68d2-48f7-914e-2179b3ffd3ea",
   "metadata": {},
   "source": [
    "### detaching computation graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "44006a03-5712-4263-a068-f5171be534bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.arange(4., requires_grad=True)\n",
    "y = x*x\n",
    "u = y\n",
    "z = u*x\n",
    "z.sum().backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef4341af-0d27-4fd7-b5b6-18b0cf9f9165",
   "metadata": {},
   "source": [
    "$$\n",
    "z=u\\cdot x=y\\cdot x=(x\\cdot x)\\cdot x=x^3\n",
    "$$\n",
    "\n",
    "```\n",
    "x --> y --> u --> z\n",
    " \\               /\n",
    "  \\-------------/\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "1658d215-4234-43b3-b3df-23c0d72907ca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.,  3., 12., 27.])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "9380b42f-f473-4c4a-8a2c-2bf67ceda5f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.arange(4., requires_grad=True)\n",
    "y = x*x\n",
    "u = y.detach()\n",
    "z = u*x\n",
    "z.sum().backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61f4ac15-b03b-474c-be89-80672ba53645",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "z=u\\cdot x\\\\\n",
    "\\frac{\\partial z}{\\partial x}=u\n",
    "\\end{split}\n",
    "$$\n",
    "\n",
    "\n",
    "```\n",
    "x --> y  u --> z\n",
    " \\               /\n",
    "  \\-------------/\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "1747223a-694e-4cbf-b8f1-4eedcfc3c19f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 1., 4., 9.])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "1685ffb1-69af-4da7-bc06-20b621741fc5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 1., 4., 9.])"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "c78ad971-06ab-4cd6-9366-c300903ad8af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ]
  },
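  {
   "cell_type": "markdown",
   "id": "c5f3e4d6-7d80-4b92-8cb3-3e4f5a6b7c03",
   "metadata": {},
   "source": [
    "A small sketch of what `detach()` changes, assuming the same `x`, `y`, `u`, `z` as in the cells above. It only inspects the `requires_grad` and `grad_fn` attributes and repeats the backward pass.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "x = torch.arange(4., requires_grad=True)\n",
    "y = x * x\n",
    "u = y.detach()  # same values as y, but cut off from the graph\n",
    "\n",
    "# y is still tracked by autograd; u is treated as a constant.\n",
    "print(y.requires_grad, y.grad_fn is not None)\n",
    "print(u.requires_grad, u.grad_fn is None)\n",
    "\n",
    "# With u detached, only the direct factor x contributes: dz/dx = u = x^2.\n",
    "z = u * x\n",
    "z.sum().backward()\n",
    "print(torch.allclose(x.grad, u))\n",
    "```\n",
    "\n",
    "Without the `detach()`, the path through `y` would contribute as well and the gradient would be $3x^2$, as in the first version above."
   ]
  },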
  {
   "cell_type": "markdown",
   "id": "bba9c673-ff5f-42db-9095-25a911d64899",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## `torch.autograd.grad`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab8dc24c-eae5-4d9f-97c1-80bbb592203d",
   "metadata": {},
   "source": [
    "### `grad()` parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdda1e53-01ae-4eaa-b787-93a9f741b4b9",
   "metadata": {},
   "source": [
    "- When `outputs` is not a scalar, the `grad_outputs` argument must be specified:\n",
    "    - `grad_outputs.shape == outputs.shape`\n",
    "    - inputs, outputs, grad_outputs\n",
    "        - inputs: $x$\n",
    "        - outputs: $y$\n",
    "        - grad\n",
    "            $$\\frac{dy}{dx}$$\n",
    "        - grad_outputs $$\\frac{dL}{dy}$$\n",
    "    - with grad_outputs (VJP: **vector-Jacobian** product)\n",
    "\n",
    "        $$\n",
    "        x.grad=\\frac{dL}{dy}\\frac{dy}{dx}\n",
    "        $$\n",
    "        \n",
    "- `retain_graph`: if True, the computation graph is kept after the call; if False, it is freed. Defaults to the value of `create_graph`.\n",
    "- `create_graph`: must be set to True in order to compute higher-order derivatives."
   ]
  },
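  {
   "cell_type": "markdown",
   "id": "d6a4f5e7-8e91-4ca3-9dc4-4f5a6b7c8d04",
   "metadata": {},
   "source": [
    "A minimal sketch of the `create_graph` point above. It uses an illustrative function $y=x^3$ that is not part of the original notebook, and the names `dy_dx` and `d2y_dx2` are introduced here only for the example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "x = torch.tensor([1., 2., 3.], requires_grad=True)\n",
    "y = x ** 3  # non-scalar output, so grad_outputs is required\n",
    "\n",
    "# First derivative: dy/dx = 3 x^2. create_graph=True records this\n",
    "# computation so that the result can be differentiated again.\n",
    "(dy_dx,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)\n",
    "\n",
    "# Second derivative: d^2y/dx^2 = 6 x.\n",
    "(d2y_dx2,) = torch.autograd.grad(dy_dx, x, grad_outputs=torch.ones_like(dy_dx))\n",
    "\n",
    "print(dy_dx)\n",
    "print(d2y_dx2)\n",
    "```\n",
    "\n",
    "With `create_graph=False` the first result would not be attached to a graph, and the second `grad` call would fail."
   ]
  },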
  {
   "cell_type": "markdown",
   "id": "5f39db0c-2068-47c0-be93-96a5f71ce510",
   "metadata": {},
   "source": [
    "### case1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0352d67f-94c5-4e50-ad00-43198ff9c07c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.8823, 0.9150, 0.3829, 0.9593],\n",
       "        [0.3904, 0.6009, 0.2566, 0.7936],\n",
       "        [0.9408, 0.1332, 0.9346, 0.5936]], requires_grad=True)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(3, 4)\n",
    "# requires_grad_() sets the flag in place\n",
    "x.requires_grad_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fdd189be-90d2-4540-b254-5c76f7e135bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1.]]),)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.autograd.grad(x.sum(), x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63163899-633a-4445-9a9d-4733c0d6fc4e",
   "metadata": {},
   "source": [
    "### case2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5377ebd0-08c9-427d-afa0-2aa15a8a7c4d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([4., 6.])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([2.0, 3.0], requires_grad=True)\n",
    "\n",
    "# Define the multi-output function: y = [x0^2, x1^2]\n",
    "y = x ** 2\n",
    "torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ba5616fc-0551-4fc2-b489-3aa34a4bcea9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4., 0.],\n",
       "        [0., 6.]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.autograd.functional.jacobian(lambda x: x**2, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ba079888-3206-4b70-ba98-823eb3ddfc9a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2., 3.], requires_grad=True)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "64548829-254c-41ae-9fc3-8b1f66a3d73b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([7., 8.]),)\n",
      "(tensor([10., 22.]),)\n"
     ]
    }
   ],
   "source": [
    "y = torch.stack([x[0]**2, x[0]*x[1], x[1]**2])\n",
    "print(torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), retain_graph=True))\n",
    "print(torch.autograd.grad(y, x, grad_outputs=torch.tensor([1., 2., 3.])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ed380197-e323-4b6c-9ddd-a3942009dada",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4., 0.],\n",
       "        [3., 2.],\n",
       "        [0., 6.]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the output with torch.stack: wrapping the entries in torch.tensor([...]) would create a new tensor that is cut off from x.\n",
    "torch.autograd.functional.jacobian(lambda x: torch.stack([x[0]**2, x[0]*x[1], x[1]**2]), x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08d3d80c-31e5-41cd-84dd-82cd878e51fd",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "&f(\\mathbf x)=\\begin{bmatrix}\n",
    "x^2\\\\\n",
    "xy\\\\\n",
    "y^2\n",
    "\\end{bmatrix}\\\\\n",
    "&\\mathbf J_f=\\begin{bmatrix}\n",
    "2x & 0\\\\\n",
    "y & x\\\\\n",
    "0 & 2y\n",
    "\\end{bmatrix}=\\begin{bmatrix}\n",
    "4 & 0\\\\\n",
    "3 & 2\\\\\n",
    "0 & 6\n",
    "\\end{bmatrix}\\\\\n",
    "&\\mathrm{vjp}=\\begin{bmatrix}1 & 2 & 3\\end{bmatrix}\\begin{bmatrix}\n",
    "4 & 0\\\\\n",
    "3 & 2\\\\\n",
    "0 & 6\n",
    "\\end{bmatrix}=\\begin{bmatrix}10 & 22\\end{bmatrix}\n",
    "\\end{split}\n",
    "$$"
   ]
  },
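  {
   "cell_type": "markdown",
   "id": "e7b5a6f8-9fa2-4db4-8ed5-5a6b7c8d9e05",
   "metadata": {},
   "source": [
    "The vector-Jacobian product above can be checked along two routes. The sketch below assumes the same point $\\mathbf x=(2,3)$; the names `f`, `v`, `vjp_explicit` and `vjp_autograd` are illustrative and local to this example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "x = torch.tensor([2.0, 3.0], requires_grad=True)\n",
    "\n",
    "def f(t):\n",
    "    return torch.stack([t[0]**2, t[0]*t[1], t[1]**2])\n",
    "\n",
    "v = torch.tensor([1., 2., 3.])\n",
    "\n",
    "# Route 1: build the full 3x2 Jacobian, then multiply on the left by v.\n",
    "J = torch.autograd.functional.jacobian(f, x)\n",
    "vjp_explicit = v @ J\n",
    "\n",
    "# Route 2: let autograd compute the product directly, never forming J.\n",
    "(vjp_autograd,) = torch.autograd.grad(f(x), x, grad_outputs=v)\n",
    "\n",
    "print(vjp_explicit, vjp_autograd)\n",
    "```\n",
    "\n",
    "Route 2 is what `backward` does internally: it needs one backward pass regardless of the number of outputs, whereas forming the full Jacobian needs one pass per output."
   ]
  },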
  {
   "cell_type": "markdown",
   "id": "d2690ca8-cd78-4d1f-a5d1-d78b0b3f7c14",
   "metadata": {},
   "source": [
    "### `flat_grad`"
   ]
  },
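  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "3aefc938-00af-4c1e-bd27-3ff6adcb6024",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flat_grad(y, x, retain_graph=False, create_graph=False):\n",
    "    # Gradient of y w.r.t. x (a tensor or a sequence of tensors), flattened into a single 1-D tensor.\n",
    "    # The graph has to be kept whenever a graph of the gradient itself is requested.\n",
    "    if create_graph:\n",
    "        retain_graph = True\n",
    "\n",
    "    g = torch.autograd.grad(y, x, retain_graph=retain_graph, create_graph=create_graph)\n",
    "    # reshape(-1) also handles non-contiguous gradients, unlike view(-1).\n",
    "    g = torch.cat([t.reshape(-1) for t in g])\n",
    "    return g"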
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "c2204f0b-bf88-4a91-b545-6556353bf850",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "    return (x**2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "3c82fe51-66be-4e5b-8c85-2351a8f79a51",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 1., 2., 3.], requires_grad=True)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.arange(4, dtype=torch.float32, requires_grad=True)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "b51e3f8b-6373-40e9-9c8e-7a257592cb2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 2., 4., 6.])"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "flat_grad(f(x), x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "d8da910c-535e-40cf-b18f-6be072f9681f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 2., 4., 6.])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f(x).backward()\n",
    "x.grad"
   ]
  },
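  {
   "cell_type": "markdown",
   "id": "f8c6b7a9-a0b3-4ec5-9fe6-6b7c8d9e0f06",
   "metadata": {},
   "source": [
    "Flattening becomes useful when `x` is a list of tensors with different shapes. The sketch below assumes hypothetical toy parameters `w` and `b` and an input `inp`, none of which appear in the original notebook, and repeats the helper so that it is self-contained.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def flat_grad(y, x, retain_graph=False, create_graph=False):\n",
    "    # Same logic as the helper above, repeated to keep this sketch self-contained.\n",
    "    if create_graph:\n",
    "        retain_graph = True\n",
    "    g = torch.autograd.grad(y, x, retain_graph=retain_graph, create_graph=create_graph)\n",
    "    return torch.cat([t.reshape(-1) for t in g])\n",
    "\n",
    "# Hypothetical toy parameters: a 2x2 weight and a length-2 bias.\n",
    "w = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)\n",
    "b = torch.tensor([0.5, -0.5], requires_grad=True)\n",
    "inp = torch.tensor([1., 1.])\n",
    "\n",
    "loss = ((w @ inp + b) ** 2).sum()\n",
    "\n",
    "# One flat vector: the 4 entries of dloss/dw followed by the 2 entries of dloss/db.\n",
    "g = flat_grad(loss, [w, b])\n",
    "print(g.shape)\n",
    "print(g)\n",
    "```\n",
    "\n",
    "A single flat gradient vector is convenient for routines that treat all parameters as one vector, such as line searches or Hessian-vector products."
   ]
  },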
  {
   "cell_type": "markdown",
   "id": "dc1dd306-bddc-4485-870d-c65a154edeb3",
   "metadata": {},
   "source": [
    "## jacobian "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71df3ca3-fb61-45b3-95f2-90be70b5525e",
   "metadata": {},
   "source": [
    "$\\mathbf f: R^n\\rightarrow R^m, \\mathbf J\\in R^{m\\times n}$\n",
    "\n",
    "$$\n",
    "\\mathbf J = \\begin{bmatrix}\n",
    "  \\dfrac{\\partial \\mathbf{f}}{\\partial x_1} & \\cdots & \\dfrac{\\partial \\mathbf{f}}{\\partial x_n}\n",
    "\\end{bmatrix}\n",
    "= \\begin{bmatrix}\n",
    "  \\nabla^{\\mathrm T} f_1 \\\\  \n",
    "  \\vdots \\\\\n",
    "  \\nabla^{\\mathrm T} f_m   \n",
    "\\end{bmatrix}\n",
    "= \\begin{bmatrix}\n",
    "    \\dfrac{\\partial f_1}{\\partial x_1} & \\cdots & \\dfrac{\\partial f_1}{\\partial x_n}\\\\\n",
    "    \\vdots                             & \\ddots & \\vdots\\\\\n",
    "    \\dfrac{\\partial f_m}{\\partial x_1} & \\cdots & \\dfrac{\\partial f_m}{\\partial x_n}\n",
    "\\end{bmatrix}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b031ad51-bbd2-499f-8603-d3b03f6a901f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "    return x * x * torch.arange(4, dtype=torch.float)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1645668a-4b0f-48d0-9c46-04fc4198d271",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "\\mathbf f(\\mathbf x)=\n",
    "\\begin{bmatrix}\n",
    "f_1(x_1)\\\\\n",
    "f_2(x_2)\\\\\n",
    "f_3(x_3)\\\\\n",
    "f_4(x_4)\n",
    "\\end{bmatrix}=\\begin{bmatrix}\n",
    "0\\\\\n",
    "x_2^2\\\\\n",
    "2 x_3^2\\\\\n",
    "3 x_4^2\n",
    "\\end{bmatrix}\\\\\n",
    "\\end{split}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "369f7ec2-22c7-465f-8561-46d48e957a1c",
   "metadata": {},
   "source": [
    "\n",
    "$$\n",
    "\\mathbf J=\\begin{bmatrix}\n",
    "\\dfrac{\\partial f_1}{\\partial x_1} & \\dfrac{\\partial f_1}{\\partial x_2} & \\dfrac{\\partial f_1}{\\partial x_3} & \\dfrac{\\partial f_1}{\\partial x_4}\\\\\n",
    "\\dfrac{\\partial f_2}{\\partial x_1} & \\dfrac{\\partial f_2}{\\partial x_2} & \\dfrac{\\partial f_2}{\\partial x_3} & \\dfrac{\\partial f_2}{\\partial x_4}\\\\\n",
    "\\dfrac{\\partial f_3}{\\partial x_1} & \\dfrac{\\partial f_3}{\\partial x_2} & \\dfrac{\\partial f_3}{\\partial x_3} & \\dfrac{\\partial f_3}{\\partial x_4}\\\\\n",
    "\\dfrac{\\partial f_4}{\\partial x_1} & \\dfrac{\\partial f_4}{\\partial x_2} & \\dfrac{\\partial f_4}{\\partial x_3} & \\dfrac{\\partial f_4}{\\partial x_4}\\\\\n",
    "\\end{bmatrix}=\\begin{bmatrix}\n",
    "0 & 0 & 0 & 0\\\\\n",
    "0 & 2x_2 & 0 & 0\\\\\n",
    "0 & 0 & 4x_3 & 0\\\\\n",
    "0 & 0 & 0 & 6x_4\\\\\n",
    "\\end{bmatrix}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "86df6c0f-a3fe-408a-957e-73b92c450b19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def jacobian(y, x, create_graph=False):\n",
    "    # Build the Jacobian one row at a time: row i is the VJP with the i-th one-hot vector.\n",
    "    jac = []\n",
    "    flat_y = y.reshape(-1)\n",
    "    grad_y = torch.zeros_like(flat_y)\n",
    "    for i in range(len(flat_y)):\n",
    "        grad_y[i] = 1.\n",
    "        grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)\n",
    "        jac.append(grad_x.reshape(x.shape))\n",
    "        grad_y[i] = 0.  # reset the one-hot vector for the next row\n",
    "    return torch.stack(jac).reshape(y.shape + x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d3df6645-3f41-41f7-b19c-ee902b30f739",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 1., 1., 1.], requires_grad=True)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.ones(4, requires_grad=True)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "9c019807-68dc-431c-a56c-47acd40730e0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 1., 2., 3.], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "b1b9b612-30e3-4970-86e2-f6b37c53b79c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4])"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f(x).reshape(-1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f75455c7-e911-4b1c-860c-3771e0e5ff73",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0.],\n",
       "        [0., 2., 0., 0.],\n",
       "        [0., 0., 4., 0.],\n",
       "        [0., 0., 0., 6.]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "J = torch.autograd.functional.jacobian(f, x)\n",
    "J"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "1b222af6-e425-46df-81d6-18a2a222fce8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0.],\n",
       "        [0., 2., 0., 0.],\n",
       "        [0., 0., 4., 0.],\n",
       "        [0., 0., 0., 6.]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jacobian(f(x), x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c836ee4-3751-4525-82ee-f9f9b8b6c9c4",
   "metadata": {},
   "source": [
    "## hessian"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "740efcde-8d7a-47a5-94d5-9a2690203da4",
   "metadata": {},
   "source": [
    "$f : R^n \\to R$\n",
    "\n",
    "$$\n",
    "(\\mathbf H_f)_{i,j} = \\frac{\\partial^2 f}{\\partial x_i \\, \\partial x_j},\\qquad \\mathbf H_f = \\begin{bmatrix}\n",
    "  \\dfrac{\\partial^2 f}{\\partial x_1^2} & \\dfrac{\\partial^2 f}{\\partial x_1\\,\\partial x_2} & \\cdots & \\dfrac{\\partial^2 f}{\\partial x_1\\,\\partial x_n} \\\\[2.2ex]\n",
    "  \\dfrac{\\partial^2 f}{\\partial x_2\\,\\partial x_1} & \\dfrac{\\partial^2 f}{\\partial x_2^2} & \\cdots & \\dfrac{\\partial^2 f}{\\partial x_2\\,\\partial x_n} \\\\[2.2ex]\n",
    "  \\vdots & \\vdots & \\ddots & \\vdots \\\\[2.2ex]\n",
    "  \\dfrac{\\partial^2 f}{\\partial x_n\\,\\partial x_1} & \\dfrac{\\partial^2 f}{\\partial x_n\\,\\partial x_2} & \\cdots & \\dfrac{\\partial^2 f}{\\partial x_n^2}\n",
    "\\end{bmatrix}\n",
    "$$"
   ]
  },
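  {
   "cell_type": "markdown",
   "id": "09d7c8ba-b1c4-4fd6-8af7-7c8d9e0f1a07",
   "metadata": {},
   "source": [
    "As a concrete instance of the definition above, here is a sketch with a hypothetical scalar function $g(\\mathbf t)=t_0^2\\,t_1$, chosen only because its Hessian has off-diagonal entries. The names `g`, `x0`, `H` and `H_expected` are introduced for this example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def g(t):\n",
    "    # Hypothetical scalar function used for illustration: g = t0^2 * t1.\n",
    "    return t[0] ** 2 * t[1]\n",
    "\n",
    "x0 = torch.tensor([2.0, 3.0])\n",
    "H = torch.autograd.functional.hessian(g, x0)\n",
    "\n",
    "# Analytic Hessian [[2*t1, 2*t0], [2*t0, 0]] evaluated at t = (2, 3).\n",
    "H_expected = torch.tensor([[6.0, 4.0], [4.0, 0.0]])\n",
    "\n",
    "print(H)\n",
    "print(torch.allclose(H, H_expected))\n",
    "```\n",
    "\n",
    "The result is symmetric, as expected for a function with continuous second partial derivatives."
   ]
  },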
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "2b285e02-f083-490c-945d-d782f17d55b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hessian(y, x):\n",
    "    # Jacobian of the Jacobian; create_graph=True keeps the first pass differentiable.\n",
    "    return jacobian(jacobian(y, x, create_graph=True), x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b6fe171d-2aa4-4dea-a3a2-ee56c93e46b9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0.],\n",
       "         [0., 2., 0., 0.],\n",
       "         [0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0.],\n",
       "         [0., 0., 4., 0.],\n",
       "         [0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0.],\n",
       "         [0., 0., 0., 6.]]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hessian(f(x), x)"
   ]
  },
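  {
   "cell_type": "markdown",
   "id": "1ae8d9cb-c2d5-4a07-9b08-8d9e0f1a2b08",
   "metadata": {},
   "source": [
    "A short derivation of why the result above has shape `(4, 4, 4)`. This is a sketch in which $c=(0,1,2,3)$ stands for the `torch.arange(4)` factor in `f`, with 0-based indices as in the code. Since $f_i(\\mathbf x)=c_i x_i^2$,\n",
    "\n",
    "$$\n",
    "\\frac{\\partial^2 f_i}{\\partial x_j\\,\\partial x_k}=2c_i\\,\\delta_{ij}\\,\\delta_{ik}\n",
    "$$\n",
    "\n",
    "so slice $i$ of the tensor is the Hessian of the scalar component $f_i$, and its only entry that can be nonzero is $2c_i\\in\\{0,2,4,6\\}$ at position $(i,i)$."
   ]
  },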
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "3ebf2251-c5d1-4e50-bd46-8253ba4300a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f_0(x):\n",
    "    return (x * x * torch.arange(4, dtype=torch.float32))[0]\n",
    "def f_1(x):\n",
    "    return (x * x * torch.arange(4, dtype=torch.float32))[1]\n",
    "def f_2(x):\n",
    "    return (x * x * torch.arange(4, dtype=torch.float32))[2]\n",
    "def f_3(x):\n",
    "    return (x * x * torch.arange(4, dtype=torch.float32))[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "fc629032-8d94-41d2-85de-3561460856b2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.]])\n",
      "tensor([[0., 0., 0., 0.],\n",
      "        [0., 2., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.]])\n",
      "tensor([[0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 4., 0.],\n",
      "        [0., 0., 0., 0.]])\n",
      "tensor([[0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 6.]])\n"
     ]
    }
   ],
   "source": [
    "print(torch.autograd.functional.hessian(f_0, x))\n",
    "print(torch.autograd.functional.hessian(f_1, x))\n",
    "print(torch.autograd.functional.hessian(f_2, x))\n",
    "print(torch.autograd.functional.hessian(f_3, x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "fee5f388-8997-47bf-b2ef-291acad6f651",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 6.]])"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from functools import partial\n",
    "def f_i(x, i):\n",
    "    return (x*x*torch.arange(4, dtype=torch.float))[i]\n",
    "\n",
    "torch.autograd.functional.hessian(partial(f_i, i=3), x)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
